#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@author: liang kang
@contact: gangkanli1219@gmail.com
@time: 2018/2/28 16:28
@desc: 这是一个基本模型模板类
"""
import tensorflow as tf


class BaseModel(object):
    """
    Base model template.

    Subclasses define the network by overriding ``_init_model``; calling an
    instance forwards all arguments to that method and returns its result.

    Parameters
    ----------
    data_format : str
        Either ``'channels_last'`` (default) or ``'channels_first'``. Stored
        as ``self.data_format`` and consulted by ``transpose``.

    Raises
    ------
    ValueError
        If ``data_format`` is not one of the supported values.
    """
    # Supported values for ``data_format``.
    _DATA_FORMATS = ('channels_last', 'channels_first')

    def __init__(self, data_format='channels_last'):
        # Reject unknown formats up front: ``transpose`` treats every value
        # other than 'channels_last' as channels-first, so a typo would
        # otherwise silently transpose the input.
        if data_format not in self._DATA_FORMATS:
            raise ValueError(
                'data_format must be one of {}, got {!r}'.format(
                    self._DATA_FORMATS, data_format))
        self.data_format = data_format

    def transpose(self, data):
        """
        Convert ``data`` to the layout selected by ``self.data_format``.

        Parameters
        ----------
        data : tf.Tensor or array-like
            A rank-4 input in channels-last order (the permutation used
            below has four entries, so rank 4 is required on that path;
            presumably NHWC -- confirm against callers).

        Returns
        -------
        The input object itself when ``data_format`` is ``'channels_last'``;
        otherwise ``tf.transpose(data, [0, 3, 1, 2])``, which moves axis 3
        (channels) to axis 1.
        """
        if self.data_format == 'channels_last':
            return data
        else:
            return tf.transpose(data, [0, 3, 1, 2])

    def _init_model(self, *args, **kwargs):
        """
        Build the model. Subclasses must override this method.

        Parameters
        ----------
        args, kwargs
            Whatever the subclass's model definition needs; passed through
            unchanged from ``__call__``.

        Returns
        -------
        Defined by the subclass.

        Raises
        ------
        NotImplementedError
            Always, in this base implementation.
        """
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        """
        Invoke the model.

        Parameters
        ----------
        args, kwargs
            Forwarded unchanged to ``_init_model``.

        Returns
        -------
        Whatever ``_init_model`` returns.
        """
        return self._init_model(*args, **kwargs)

